from __future__ import absolute_import, division, print_function, unicode_literals

import torch
import torch.fx
import pydot
from typing import Dict, Any
from torch.fx.node import _get_qualified_name

# Fill colour of each FX node, keyed by its opcode (``node.op``). The values
# end up as pydot node attributes; "AliceBlue" is pre-quoted in the literal.
_COLOR_MAP = dict(
    placeholder='"AliceBlue"',
    call_module="LemonChiffon1",
    call_function="PeachPuff1",
    get_param="Yellow2",
    get_attr="LightGrey",
    call_method="LavenderBlush1",
    output="PowderBlue",
)

# pydot node attributes shared by every parameter / buffer ("weight") node
# attached to a call_module node.
_WEIGHT_TEMPLATE = dict(
    shape="record",
    fillcolor="Salmon",
    style='"filled,rounded"',
    fontcolor="#000000",
)


class FxGraphDrawer:
    """
    Visualize a torch.fx.Graph with graphviz (through pydot).

    One dot graph is built for the top-level GraphModule, plus one for every
    ``call_module`` target that is itself a ``torch.fx.GraphModule``.

    Basic usage:
        g = FxGraphDrawer(symbolic_traced, "resnet18")
        # pydot's create_* methods return bytes, so open in binary mode.
        with open("a.svg", "wb") as f:
            f.write(g.get_main_dot_graph().create_svg())
    """

    def __init__(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool = False):
        """
        Build the dot graphs eagerly.

        Args:
            graph_module: the traced module to visualize.
            name: name of the main dot graph. Graphs for GraphModule
                submodules are stored under ``f"{name}_{node.target}"``.
            ignore_getattr: when True, ``get_attr`` nodes and their outgoing
                edges are left out of every generated graph.
        """
        self._name = name
        self._dot_graphs = {name: self._to_dot(graph_module, name, ignore_getattr)}

        for node in graph_module.graph.nodes:
            if node.op != "call_module":
                continue

            submod_name = f"{name}_{node.target}"
            # The same submodule may be called several times; its graph only
            # needs to be built once.
            if submod_name in self._dot_graphs:
                continue

            leaf_node = self._get_leaf_node(graph_module, node)

            if not isinstance(leaf_node, torch.fx.GraphModule):
                continue

            self._dot_graphs[submod_name] = self._to_dot(leaf_node, submod_name, ignore_getattr)

    def get_main_dot_graph(self) -> pydot.Dot:
        """Return the dot graph of the top-level GraphModule."""
        return self._dot_graphs[self._name]

    def get_submod_dot_graph(self, submod_name) -> pydot.Dot:
        """
        Return the dot graph of the GraphModule submodule ``submod_name``.

        Raises KeyError if ``submod_name`` was not a GraphModule called by the
        main graph (only those get their own dot graph).
        """
        return self._dot_graphs[f"{self._name}_{submod_name}"]

    def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]:
        """Return every generated dot graph, keyed by graph name."""
        return self._dot_graphs

    def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]:
        """Return pydot node attributes for ``node``, coloured by its opcode."""
        template = {
            "shape": "record",
            "fillcolor": "#CAFFE3",
            "style": '"filled,rounded"',
            "fontcolor": "#000000",
        }
        # Opcodes missing from _COLOR_MAP keep the template's default colour
        # instead of raising KeyError.
        template["fillcolor"] = _COLOR_MAP.get(node.op, template["fillcolor"])
        return template

    def _get_leaf_node(
        self, module: torch.nn.Module, node: torch.fx.Node
    ) -> torch.nn.Module:
        """
        Resolve the dotted ``node.target`` (e.g. ``"layer1.0.conv"``) against
        ``module`` by walking one attribute per path component.

        Raises:
            RuntimeError: if some component is not an attribute of the object
                reached so far.
        """
        py_obj = module
        assert isinstance(node.target, str)
        for atom in node.target.split("."):
            if not hasattr(py_obj, atom):
                raise RuntimeError(f"{py_obj} does not have attribute {atom}!")
            py_obj = getattr(py_obj, atom)
        return py_obj

    def _typename(self, target: Any) -> str:
        """Return a printable name for a module instance, a string target, or a callable."""
        if isinstance(target, torch.nn.Module):
            return torch.typename(target)

        if isinstance(target, str):
            return target

        return _get_qualified_name(target)

    def _get_node_label(self, module: torch.fx.GraphModule, node: torch.fx.Node) -> str:
        """
        Build the graphviz "record" label for ``node``.

        The label starts with the node name and opcode. For ``call_module``
        nodes it then lists the submodule's type and its ``__constants__``;
        for other nodes it lists the target. If ``node.meta["tensor_meta"]``
        is set, dtype / shape / stride fields are appended.
        """
        label = "{" + f"{node.name}|op_code={node.op}"

        if node.op == "call_module":
            leaf_module = self._get_leaf_node(module, node)
            label += r"\l" + self._typename(leaf_module) + r"\l|"
            extra = ""
            if hasattr(leaf_module, "__constants__"):
                extra = r"\l".join(
                    [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__]  # type: ignore[union-attr]
                )
            label += extra + r"\l"
        else:
            label += "|" + self._typename(node.target) + r"\l"

        tensor_meta = node.meta.get("tensor_meta")
        if tensor_meta:
            for attr_name in ("dtype", "shape", "stride"):
                # Attributes the metadata object lacks are rendered as "none".
                value = getattr(tensor_meta, attr_name, "none")
                # Test against None rather than truthiness: a 0-d tensor has
                # shape torch.Size([]) and stride (), both falsy, and would
                # otherwise silently lose those lines.
                if value is not None:
                    label += f"|{attr_name}=" + str(value) + r"\l"

        return label + "}"

    def _get_tensor_label(self, t: torch.Tensor) -> str:
        """Return ``"<dtype>[<dims>]"`` for ``t``, terminated by a left-justified line break."""
        return str(t.dtype) + str(list(t.shape)) + r"\l"

    def _add_module_params_or_buffers(
        self,
        dot_graph: pydot.Dot,
        node: torch.fx.Node,
        leaf_module: torch.nn.Module,
        is_param: bool,
    ) -> None:
        """
        Add one weight node per parameter (``is_param=True``) or buffer
        (``is_param=False``) of ``leaf_module``. Each weight node gets an
        edge into ``node``.
        """
        named_tensors = (
            leaf_module.named_parameters() if is_param else leaf_module.named_buffers()
        )
        # Parenthesised on purpose: only the kind word depends on is_param.
        # Both parameters and buffers are labelled "<name>|op_code=get_<kind>".
        kind = "parameter" if is_param else "buffer"
        for pname, ptensor in named_tensors:
            pname1 = node.name + "." + pname
            label1 = pname1 + "|op_code=get_" + kind + r"\l"
            dot_w_node = pydot.Node(
                pname1,
                label="{" + label1 + self._get_tensor_label(ptensor) + "}",
                **_WEIGHT_TEMPLATE,
            )
            dot_graph.add_node(dot_w_node)
            dot_graph.add_edge(pydot.Edge(pname1, node.name))

    def _to_dot(self, graph_module: torch.fx.GraphModule, name: str, ignore_getattr: bool) -> pydot.Dot:
        """
        Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph,
        because call_module targets must be resolved against the module.

        Pass 1 adds one pydot node per FX node. For call_module nodes whose
        target is not itself a GraphModule, it also adds that module's
        parameters and buffers. Pass 2 adds a data-flow edge from every node
        to each of its users.
        """
        dot_graph = pydot.Dot(name, rankdir="TB")

        for node in graph_module.graph.nodes:
            if ignore_getattr and node.op == "get_attr":
                continue

            style = self._get_node_style(node)
            dot_node = pydot.Node(
                node.name, label=self._get_node_label(graph_module, node), **style
            )
            dot_graph.add_node(dot_node)

            if node.op == "call_module":
                leaf_module = self._get_leaf_node(graph_module, node)

                # GraphModule submodules get their own dot graph (see __init__),
                # so their weights are not inlined here.
                if not isinstance(leaf_module, torch.fx.GraphModule):
                    self._add_module_params_or_buffers(dot_graph, node, leaf_module, True)
                    self._add_module_params_or_buffers(dot_graph, node, leaf_module, False)

        for node in graph_module.graph.nodes:
            if ignore_getattr and node.op == "get_attr":
                continue

            for user in node.users:
                dot_graph.add_edge(pydot.Edge(node.name, user.name))

        return dot_graph
